# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Create the ImageNet2012 dataset from https://image-net.org/. """

import os
import shutil
import tempfile
from typing import Callable, Optional, Union, Tuple
from scipy import io

import mindspore.dataset.vision.c_transforms as transforms

from mindvision.dataset.download import label2index, read_dataset
from mindvision.dataset.meta import Dataset, ParseDataset
from mindvision.io.images import image_format
from mindvision.io.path import check_dir_exist, check_file_valid, load_json_file, save_json_file
from mindvision.check_param import Validator
from mindvision.engine.class_factory import ClassFactory, ModuleType

__all__ = ["ImageNet", "ParseImageNet"]


@ClassFactory.register(ModuleType.DATASET)
class ImageNet(Dataset):
    """
    ImageNet2012 dataset loader. It expects a train directory and a val directory.
    The directory structure of the ImageNet dataset looks like:

    .imagenet/
    ├── train/  (1000 directories and 1281167 images)
    │  ├── n04347754/
    │  │   ├── 000001.jpg
    │  │   ├── 000002.jpg
    │  │   └── ....
    │  └── n04347756/
    │      ├── 000001.jpg
    │      ├── 000002.jpg
    │      └── ....
    └── val/   (1000 directories and 50000 images)
       ├── n04347754/
       │   ├── 000001.jpg
       │   ├── 000002.jpg
       │   └── ....
       └── n04347756/
           ├── 000001.jpg
           ├── 000002.jpg
           └── ....

    Args:
        path (str): Root directory of the ImageNet dataset or inference image.
        split (str): The dataset split, supports "train", "val", or "infer". Default: "train".
        transform (callable, optional): A function transform that takes in an image. Default: None.
        target_transform (callable, optional): A function transform that takes in a label. Default: None.
        batch_size (int): Batch size of dataset. Default: 64.
        resize (int, tuple): The output size of the resized image. If size is an integer, the smaller
            edge of the image will be resized to this value with the same image aspect ratio. If size
            is a sequence of length 2, it should be (height, width). Default: 224.
        repeat_num (int): The repeat num of dataset. Default: 1.
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Default: None.
        download (bool): Whether to download the dataset. Not supported; True raises. Default: False.
        num_parallel_workers (int): Number of subprocesses used to fetch the dataset in parallel.
            Default: 1.
        num_shards (int, optional): Number of shards that the dataset will be divided into.
            Default: None.
        shard_id (int, optional): The shard ID within num_shards. Default: None.
    """

    def __init__(self,
                 path: str,
                 split: str = "train",
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 batch_size: int = 64,
                 resize: Union[Tuple[int, int], int] = 224,
                 repeat_num: int = 1,
                 shuffle: Optional[bool] = None,
                 download: bool = False,
                 num_parallel_workers: int = 1,
                 num_shards: Optional[int] = None,
                 shard_id: Optional[int] = None):
        Validator.check_string(split, ["train", "val", "infer"], "split")
        # Inference input has no class-per-directory layout, so fall back to the
        # generic reader; "train"/"val" use the ImageNet-specific reader below.
        load_data = read_dataset if split == "infer" else self.read_dataset

        super(ImageNet, self).__init__(path=path,
                                       split=split,
                                       load_data=load_data,
                                       transform=transform,
                                       target_transform=target_transform,
                                       batch_size=batch_size,
                                       repeat_num=repeat_num,
                                       resize=resize,
                                       shuffle=shuffle,
                                       num_parallel_workers=num_parallel_workers,
                                       num_shards=num_shards,
                                       shard_id=shard_id,
                                       download=download)

    @property
    def index2label(self):
        """
        Get the mapping of indexes and labels.

        Returns:
            dict: class index (int, ordered by sorted wnid) -> primary class name (str).
        """
        parse_imagenet = ParseImageNet(self.path)
        meta_path = os.path.join(parse_imagenet.path, parse_imagenet.meta_file)
        if not os.path.exists(meta_path):
            parse_imagenet.parse_devkit()

        # NOTE: parse_devkit() saves the mapping under the "wnid2class" key;
        # reading the non-existent "wnid2class_name" key raised KeyError.
        wnid2class = load_json_file(meta_path)["wnid2class"]
        mapping = {}

        # Class indices follow the lexicographic order of the wnids, matching
        # the order produced by label2index over the split directories.
        for index, (_, class_name) in enumerate(sorted(wnid2class.items(), key=lambda x: x[0])):
            # class_name is a tuple of synonyms; keep the primary one.
            mapping[index] = class_name[0]

        return mapping

    def download_dataset(self):
        """Downloading ImageNet is not supported; the archives must be obtained manually."""
        raise ValueError("Imagenet dataset download is not supported.")

    def default_transform(self):
        """Set the default transform for ImageNet dataset."""
        # Channel-wise ImageNet mean/std scaled to the [0, 255] pixel range.
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
        # Margin added before center-cropping (e.g. resize 224 -> shorter edge 256).
        scale = 32

        if self.split == "train":
            # Define map operations for training dataset
            trans = [
                transforms.RandomCropDecodeResize(size=self.resize,
                                                  scale=(0.08, 1.0),
                                                  ratio=(0.75, 1.333)),
                transforms.RandomHorizontalFlip(prob=0.5),
                transforms.Normalize(mean=mean, std=std),
                transforms.HWC2CHW()
            ]
        else:
            # `resize` may be an int or a (height, width) sequence; add the crop
            # margin element-wise so a tuple size does not raise TypeError.
            if isinstance(self.resize, int):
                enlarged = self.resize + scale
            else:
                enlarged = [edge + scale for edge in self.resize]

            # Define map operations for inference dataset
            trans = [
                transforms.Decode(),
                transforms.Resize(enlarged),
                transforms.CenterCrop(self.resize),
                transforms.Normalize(mean=mean, std=std),
                transforms.HWC2CHW()
            ]
        return trans

    def read_dataset(self):
        """
        Read each image and its corresponding label from directory.

        Returns:
            tuple: (list of image file paths, list of numeric labels).

        Raises:
            ValueError: If a class directory contains no valid image file.
        """
        check_dir_exist(self.get_path)

        available_label = set()
        images_label, images_path = [], []
        label_to_idx = label2index(self.get_path)

        # Iterate each file in the path
        for label, idx in label_to_idx.items():
            for file_name in os.listdir(os.path.join(self.get_path, label)):
                if check_file_valid(file_name, image_format):
                    images_path.append(os.path.join(self.get_path, label, file_name))
                    images_label.append(idx)
                    available_label.add(label)

        # Labels whose directories held no valid image file.
        empty_label = set(label_to_idx) - available_label

        if empty_label:
            raise ValueError(f"Found invalid file for the label {','.join(empty_label)}.")

        return images_path, images_label


class ParseImageNet(ParseDataset):
    """
    Parse ImageNet dataset and generate the json file (file name: imagenet_meta.json).
    The ImageNet dataset looks like:

        .imagenet/
        ├── ILSVRC2012_devkit_t12.tar.gz
        ├── ILSVRC2012_img_train.tar
        └── ILSVRC2012_img_val.tar

        or:

        .imagenet/
        ├── ILSVRC2012_devkit_t12.tar.gz
        ├── train/
        └── val/

    Args:
        path (str): The root path of the ImageNet2012 dataset which must include
            ILSVRC2012_devkit_t12.tar.gz and train/val compressed package or directory.
    """

    # Key is the imagenet dataset split name, value is a tuple of (tar name, md5).
    imagenet_dict = {
        'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
        'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
        'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
    }
    # Files inside the devkit "data" directory that are parsed below.
    devkit_file = ["meta.mat", "ILSVRC2012_validation_ground_truth.txt"]
    meta_file = "imagenet_meta.json"

    def __parse_meta_mat(self, devkit_path):
        """
        Parse the mat file (meta.mat).

        Returns:
            tuple: (ILSVRC index -> wnid dict, wnid -> class-name tuple dict).
        """
        metafile = os.path.join(devkit_path, "data", self.devkit_file[0])
        meta = io.loadmat(metafile, squeeze_me=True)['synsets']

        # Keep only leaf synsets (num_children == 0); those are the actual classes.
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

        idcs, wnids, classes = list(zip(*meta))[:3]
        # Each class entry is a comma-separated list of synonyms; keep them all.
        class_names = [tuple(clss.split(', ')) for clss in classes]
        idx2wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
        wnid2class = {wnid: clss for wnid, clss in zip(wnids, class_names)}
        return idx2wnid, wnid2class

    def __parse_groundtruth(self, devkit_path):
        """Parse ILSVRC2012_validation_ground_truth.txt into a list of class indices."""
        val_gt = os.path.join(devkit_path, "data", self.devkit_file[1])
        # One integer label per line; pin the encoding so parsing does not
        # depend on the platform default.
        with open(val_gt, "r", encoding="utf-8") as f:
            val_idx2image = f.readlines()
        return [int(i) for i in val_idx2image]

    def __parse_train(self):
        """
        Parse the train images archive of the ImageNet2012 classification dataset.

        Raises:
            RuntimeError: If the archive is missing or fails the md5 check.
        """
        file_name = self.imagenet_dict["train"][0]
        md5 = self.imagenet_dict["train"][1]

        if not self.download.check_md5(os.path.join(self.path, file_name), md5):
            raise RuntimeError(f"The imagenet dataset {file_name} is not exist or md5 check failed.")

        # Decompressing  ILSVRC2012_img_train.tar
        train_path = os.path.join(self.path, "train")
        self.download.extract_archive(os.path.join(self.path, file_name), train_path)

        # The train tar contains one tar per class; extract each into a
        # directory named after the archive (the wnid) and drop the inner tar.
        archives = [os.path.join(train_path, archive) for archive in os.listdir(train_path)]
        for archive in archives:
            self.download.extract_archive(archive, os.path.splitext(archive)[0])
            os.remove(archive)

    def __parse_val(self):
        """
        Parse the validation images archive of the ImageNet2012 classification dataset.

        Raises:
            RuntimeError: If the archive is missing or fails the md5 check.
        """
        file_name = self.imagenet_dict["val"][0]
        md5 = self.imagenet_dict["val"][1]

        val_wnids = load_json_file(os.path.join(self.path, self.meta_file))["val_wnids"]
        if not self.download.check_md5(os.path.join(self.path, file_name), md5):
            raise RuntimeError(f"The imagenet dataset {file_name} is not exist or md5 check failed.")

        # Decompressing  ILSVRC2012_img_val.tar
        val_path = os.path.join(self.path, "val")
        self.download.extract_archive(os.path.join(self.path, file_name), val_path)

        # Move each image file into its class directory: val_wnids is ordered
        # by validation image number, so zip it with the sorted file list.
        img_file = sorted([os.path.join(val_path, i) for i in os.listdir(val_path)])
        for i in set(val_wnids):
            os.mkdir(os.path.join(val_path, i))
        for i, j in zip(val_wnids, img_file):
            shutil.move(j, os.path.join(val_path, i, os.path.basename(j)))

    def parse_devkit(self):
        """
        Parse the devkit archive of the ImageNet2012 classification dataset and
        save meta info (wnid2class mapping and ordered validation wnids) in a json file.

        Raises:
            RuntimeError: If the devkit archive is missing or fails the md5 check.
        """
        file_name = self.imagenet_dict["devkit"][0]
        md5 = self.imagenet_dict["devkit"][1]

        if not self.download.check_md5(os.path.join(self.path, file_name), md5):
            raise RuntimeError(f"The imagenet dataset {file_name} is not exist or md5 check failed.")

        with tempfile.TemporaryDirectory() as temp_dir:
            # Decompressing  ILSVRC2012_devkit_t12.tar.gz
            self.download.extract_archive(os.path.join(self.path, file_name), temp_dir)
            devkit_path = os.path.join(temp_dir, "ILSVRC2012_devkit_t12")
            idx2wnid, wnid2class = self.__parse_meta_mat(devkit_path)
            val_idcs = self.__parse_groundtruth(devkit_path)
            val_wnids = [idx2wnid[idx] for idx in val_idcs]

            # Generating imagenet_meta.json which saves the values of wnid2class and val_wnids
            dict_json = {"wnid2class": wnid2class, "val_wnids": val_wnids}
            save_json_file(os.path.join(self.path, self.meta_file), dict_json)

    def parse_dataset(self):
        """Parse the devkit archives of ImageNet dataset; skip steps already done on disk."""
        if not os.path.exists(os.path.join(self.path, self.meta_file)):
            self.parse_devkit()
        if not os.path.isdir(os.path.join(self.path, "train")):
            self.__parse_train()
        if not os.path.isdir(os.path.join(self.path, "val")):
            self.__parse_val()
